package code

import (
	"bytes"
	"fmt"
	"gitee.com/tjloved/tool/dbext"
	"gitee.com/tjloved/tool/helper"
	"go/format"
	"os"
	"strings"
	"text/template"
)

// Translate is a callback that converts text q from `from` to `to` and returns
// the result. NOTE(review): presumably from/to are language codes and the
// result is translated text; no implementation or caller is visible here.
type Translate func(q, from, to string) string

// DbGetter supplies the schema metadata that GormModelGenerator turns into
// model source files.
type DbGetter interface {
	// GetTables returns metadata for the requested tables.
	GetTables(tableNames []string) []dbext.Table
	// GetColumns returns column metadata for the named table.
	GetColumns(tableName string) []dbext.Column
}

// GormModelGenerator generates GORM model source files from the table and
// column metadata returned by its DbGetter.
type GormModelGenerator struct {
	// DbGetter is the source of table and column metadata.
	DbGetter DbGetter
}

// NewGormModelGenerator returns a generator whose schema metadata comes from a
// dbext.TableInfoGetter bound to the given connection name.
func NewGormModelGenerator(connection string) *GormModelGenerator {
	return &GormModelGenerator{
		DbGetter: &dbext.TableInfoGetter{ConnectionName: connection},
	}
}
// GetModelDir returns the directory into which Gen writes model files. The
// path is relative, so it is resolved against the process working directory.
func (rec *GormModelGenerator) GetModelDir() string {
	const modelDir = "model"
	return modelDir
}

// Gen generates one GORM model source file per table into GetModelDir().
//
// tableNames is passed straight to DbGetter.GetTables; which tables an empty
// list selects is up to the DbGetter implementation. Unless force is true,
// tables whose target file already exists are skipped.
//
// Each generated struct gets:
//   - a TableName method returning the table name;
//   - an AfterFind(db *gorm.DB) error method, placed first in the method list,
//     that calls every generated Set<Field>String method and returns nil;
//   - for int8/uint8 columns, a derived <Field>String field tagged gorm:"-"
//     and a setter mapping the column's enum constants (from Column.ToEnum)
//     to their comments (the setter is empty when there are no constants);
//   - for columns whose name ends in "_at", a derived <Field>String field and
//     a setter formatting the value with carbon as "Y-m-d H:i:s".
//
// The formatted source is printed to stdout before being written. Template
// parse/execute failures panic. If a table's source fails to format, the error
// is printed, its existing file is left untouched, and generation continues
// with the next table. Write failures are printed as well.
func (rec *GormModelGenerator) Gen(force bool, tableNames ...string) {
	tables := rec.DbGetter.GetTables(tableNames)
	if len(tables) == 0 {
		return
	}

	// The template is identical for every table, so parse it once.
	t, err := template.New("").Parse(StructTemplate)
	if err != nil {
		panic(err)
	}

	modelDir := rec.GetModelDir()
	var buffer bytes.Buffer
	for _, table := range tables {
		fileName := fmt.Sprintf("%s%c%s.go", modelDir, os.PathSeparator, table.Name)
		if !force {
			if _, err := os.Stat(fileName); err == nil {
				fmt.Println("文件已存在:", fileName)
				continue
			}
		}

		stru := Struct{
			Name:      helper.StringToCamelCase(table.Name),
			ShortName: "this",
			Comment:   table.Comment,
		}

		var (
			cols      []Column
			functions []Function
			imports   Imports
			constants []ConstField
		)

		// %q keeps the generated string literal valid whatever the name contains.
		functions = append(functions, Function{
			Name:     "TableName",
			Receiver: stru,
			Body:     []string{fmt.Sprintf("return %q", table.Name)},
			Return:   []Param{{Name: "string", ShortName: ""}},
		})

		// AfterFind collects one call per derived-field setter below and is
		// emitted ahead of the other methods.
		afterFind := Function{
			Name:     "AfterFind",
			Receiver: stru,
			Return:   []Param{{Name: "error"}},
			Params:   []Param{{Name: "*gorm.DB", ShortName: "db"}},
		}
		imports.AppendUnique(Import{Name: "\"gorm.io/gorm\""})

		for _, column := range rec.DbGetter.GetColumns(table.Name) {
			name := column.GetName()
			colType := column.GetType()
			c := Column{
				Name:     name,
				JsonName: column.Field,
				Type:     colType,
				Comment:  column.GetComment(),
				Tag:      "`" + column.GetGormTag() + "`",
			}
			constFields := c.ToEnum()
			constants = append(constants, constFields...)
			cols = append(cols, c)

			// hasDerived marks columns that get a companion "<Name>String"
			// field, a setter that fills it, and a call from AfterFind.
			hasDerived := false
			var setterBody []string
			switch {
			case colType == "int8" || colType == "uint8":
				hasDerived = true
				if len(constFields) > 0 {
					setterBody = append(setterBody, fmt.Sprintf("switch %s.%s{", stru.ShortName, name))
					for _, constField := range constFields {
						// %q escapes quotes/newlines in the comment so the
						// generated literal always compiles.
						setterBody = append(setterBody,
							fmt.Sprintf("case %s%s:", stru.Name, constField.Name),
							fmt.Sprintf("%s.%sString=%q", stru.ShortName, name, constField.Comment),
						)
					}
					setterBody = append(setterBody, "}")
				}
			case colType == "decimal.Decimal":
				imports.AppendUnique(Import{Name: "\"github.com/shopspring/decimal\""})
			case strings.HasSuffix(column.Field, "_at"):
				hasDerived = true
				imports.AppendUnique(Import{Name: "\"github.com/golang-module/carbon\""})
				setterBody = append(setterBody, fmt.Sprintf(
					"%s.%sString = carbon.CreateFromTimestamp(%s.%s).Format(\"Y-m-d H:i:s\")",
					stru.ShortName, name, stru.ShortName, name,
				))
			}
			if !hasDerived {
				continue
			}

			setter := "Set" + name + "String"
			functions = append(functions, Function{
				Name:     setter,
				Receiver: stru,
				Body:     setterBody,
			})
			// NOTE(review): the derived field reuses column.Field as JsonName,
			// the same as the source column; if StructTemplate emits it as a
			// json tag, both fields share one JSON key — confirm intent.
			cols = append(cols, Column{
				Name:     name + "String",
				JsonName: column.Field,
				Type:     "string",
				Comment:  column.GetComment(),
				Tag:      "`gorm:\"-\"`",
			})
			afterFind.Body = append(afterFind.Body, fmt.Sprintf("%s.%s()", stru.ShortName, setter))
		}
		afterFind.Body = append(afterFind.Body, "return nil")
		functions = append([]Function{afterFind}, functions...)

		data := map[string]any{
			"PackageName":    "model",
			"Struct":         stru,
			"Columns":        cols,
			"Functions":      functions,
			"BeforeFunction": afterFind, // template data key kept unchanged
			"Imports":        imports,
			"Constants":      constants,
		}

		// Render and format before touching the file system, so a template or
		// gofmt problem never truncates an existing model file.
		buffer.Reset()
		if err := t.Execute(&buffer, data); err != nil {
			panic(err)
		}
		formattedCode, err := format.Source(buffer.Bytes())
		if err != nil {
			fmt.Println("格式化错误:", err)
			continue
		}
		fmt.Println(string(formattedCode))

		// os.WriteFile opens, writes and closes in one call, so no file handle
		// outlives this iteration.
		if err := os.MkdirAll(modelDir, 0o755); err != nil {
			fmt.Println("写入失败:", err)
			continue
		}
		if err := os.WriteFile(fileName, formattedCode, 0o666); err != nil {
			fmt.Println("写入失败:", err)
		}
	}
}
